# modified from: https://github.com/openai/CLIP/blob/a9b1bf5920416aaeaec965c25dd9e8f98c864f16/clip/model.py

from typing import Tuple
from collections import OrderedDict
import math
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F

from .merge import merge, unmerge
from einops import rearrange


class Adapter(nn.Module):
    """Bottleneck MLP adapter: ``D -> int(D * mlp_ratio) -> D`` with an optional residual.

    Args:
        D_features: input/output feature size (the last dimension of ``x``).
        mlp_ratio: hidden width as a fraction of ``D_features``; truncated with ``int``.
        act_layer: activation class, instantiated with no arguments.
        skip_connect: when True ``forward`` returns ``x + adapter(x)``, otherwise
            only ``adapter(x)``.
    """

    def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
        super().__init__()
        hidden = int(D_features * mlp_ratio)
        self.skip_connect = skip_connect
        # Registration order (act, D_fc1, D_fc2) is kept stable for state_dict / seeded init.
        self.act = act_layer()
        self.D_fc1 = nn.Linear(D_features, hidden)
        self.D_fc2 = nn.Linear(hidden, D_features)

    def forward(self, x):
        # The Linear layers act on the last dim; callers in this file pass (B*T, P, D) tokens.
        delta = self.D_fc2(self.act(self.D_fc1(x)))
        return x + delta if self.skip_connect else delta


class local_Adapter(nn.Module):
    """Depth-wise Conv3d adapter applied to the patch tokens of every frame.

    Patch tokens (all but the first token of each frame) are projected from
    ``in_channels`` to ``adapter_channels``, passed through a depth-wise
    ``nn.Conv3d`` over a ``(T, H, W)`` grid, projected back, and added to the
    input as a residual. The first (class) token passes through unchanged.

    The conv weight/bias and both linear biases are zero-initialised. At
    construction the residual is therefore ``fc2(0) == 0``, and the module
    starts as the identity.

    Args:
        in_channels: token feature size ``C``.
        adapter_channels: bottleneck width; also the conv's channel and group count.
        kernel_size: 3-tuple conv kernel. Padding is ``k // 2`` per dim, so odd
            sizes keep the ``(T, H, W)`` grid shape.
    """

    def __init__(self, in_channels, adapter_channels, kernel_size):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, adapter_channels)
        self.conv = nn.Conv3d(
            adapter_channels, adapter_channels,
            kernel_size=kernel_size,
            stride=(1, 1, 1),
            padding=tuple(x // 2 for x in kernel_size),
            groups=adapter_channels,
        )
        self.fc2 = nn.Linear(adapter_channels, in_channels)
        nn.init.constant_(self.conv.weight, 0.)
        nn.init.constant_(self.conv.bias, 0.)
        nn.init.constant_(self.fc1.bias, 0.)
        nn.init.constant_(self.fc2.bias, 0.)

    def forward(self, x, T):
        """Add the adapter residual to the patch tokens.

        Args:
            x: ``(B*T, 1 + H*W, C)`` tokens with a square patch grid.
            T: number of frames per clip (``B*T`` must be divisible by ``T``).

        Returns:
            A new tensor shaped like ``x``; ``x`` itself is not modified.
        """
        BT, L, C = x.size()
        B = BT // T
        Ca = self.conv.in_channels
        H = W = round(math.sqrt(L - 1))
        assert L - 1 == H * W
        cls_tok, patches = x[:, :1, :], x[:, 1:, :]
        delta = self.fc1(patches)
        # (B*T, H*W, Ca) -> (B, Ca, T, H, W) so the Conv3d runs over time and space.
        delta = delta.view(B, T, H, W, Ca).permute(0, 4, 1, 2, 3).contiguous()
        delta = self.conv(delta)
        delta = delta.permute(0, 2, 3, 4, 1).contiguous().view(BT, L - 1, Ca)
        delta = self.fc2(delta)
        # Build the result out of place so the caller's tensor is never written to.
        return torch.cat((cls_tok, patches + delta), dim=1)


def scaled_dot_product_attention(query, key, value, merge_token=None, frame_scale=None, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    """Explicit scaled dot-product attention with optional per-group key weighting.

    Mirrors the pure-PyTorch reference of
    ``torch.nn.functional.scaled_dot_product_attention`` and adds
    ``frame_scale``. The keys are treated as one leading token followed by
    ``merge_token`` equally sized groups. ``log(frame_scale)`` is added to the
    logits of each group (the leading token gets weight 1), which multiplies
    that group's unnormalised attention weights.

    Args:
        query: ``(B, ..., L, E)``; must be 4-D ``(B, heads, L, E)`` when
            ``frame_scale`` is used.
        key: ``(B, ..., S, E)``.
        value: ``(B, ..., S, Ev)``.
        merge_token: number of key groups; required when ``frame_scale`` is given.
        frame_scale: optional ``(B, merge_token)`` weights. Their log is taken,
            so values should be positive.
        attn_mask: optional boolean mask (True = attend) or additive float mask,
            broadcastable to ``(L, S)``.
        dropout_p: dropout probability on the attention weights. It is applied
            whenever ``dropout_p > 0``, independent of any module's training mode.
        is_causal: apply a lower-triangular mask; must not be combined with ``attn_mask``.
        scale: logit scale; defaults to ``1 / sqrt(E)``.

    Returns:
        ``(B, ..., L, Ev)`` attention output.

    Raises:
        ValueError: if ``frame_scale`` is given without ``merge_token``.
    """
    B, L, S = query.size(0), query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Create the mask on attn_bias's device so masked_fill_ sees matching devices.
        causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor

    if frame_scale is not None:
        if merge_token is None:
            raise ValueError("merge_token is required when frame_scale is given")
        # S = 1 leading token + merge_token groups of equal size.
        tokens_per_group = (S - 1) // merge_token
        frame_scale = frame_scale.unsqueeze(-1).expand(-1, -1, tokens_per_group).reshape(B, -1)
        frame_scale = torch.cat((torch.ones(B, 1, device=frame_scale.device, dtype=frame_scale.dtype), frame_scale), dim=1)
        # log in fp32 for precision, then broadcast as (B, 1, 1, S) over heads and queries.
        attn_weight = attn_weight + attn_bias + frame_scale.to(torch.float32).log()[:, None, None, :].to(attn_weight.dtype)
    else:
        attn_weight += attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight.to(value.dtype) @ value

class LayerNorm(nn.LayerNorm):
    """Thin subclass of ``nn.LayerNorm`` kept for CLIP naming compatibility.

    Unlike upstream CLIP, this version does not up-cast to fp32. It normalises
    in the input's own dtype by delegating directly to ``nn.LayerNorm.forward``.
    """

    def forward(self, x: torch.Tensor):
        return nn.LayerNorm.forward(self, x)

class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class ResidualAttentionBlock(nn.Module):
    """CLIP residual attention block extended with local and (optionally) global adapters.

    For batch-first tokens ``x`` of shape ``(B*T, P, C)`` with ``P = 1 + H*W``,
    ``forward`` applies:

    1. if ``additional_local_adapter``: a local conv adapter and an MLP adapter;
    2. pre-LN self-attention with residual;
    3. if ``present_layer`` is in ``Block_layers``: a global branch. It projects
       to ``C // 4``, calls ``merge`` over the ``T`` frames, runs attention over
       the merged sequence (the ``scale`` returned by ``merge`` is used as the
       per-group key weight), then calls ``unmerge`` back to ``T`` frames and
       projects to ``C``;
    4. a local conv adapter and an MLP adapter, plus the global-branch output;
    5. pre-LN MLP with residual.

    Args:
        d_model: token width ``C``.
        n_head: number of attention heads.
        present_layer: index of this block in the transformer stack.
        Block_layers: layer indices that get the global (merge) branch.
        merge_token: number of merged slots passed to ``merge``.
        additional_local_adapter: add a second local/MLP adapter pair before attention.
        attn_mask: stored as ``self.attn_mask``; ``attention`` does not apply it.
    """

    def __init__(self, d_model: int,
        n_head: int,
        present_layer: int,
        Block_layers: list,
        merge_token: int,
        additional_local_adapter: bool,
        attn_mask: torch.Tensor = None):
        super().__init__()

        self.present_layer = present_layer
        self.Block_layers = Block_layers
        self.merge_token = merge_token

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # NOTE: kept for CLIP interface compatibility; attention() never reads it.
        self.attn_mask = attn_mask

        if self.present_layer in self.Block_layers:
            # Bottleneck projections C -> C/4 -> C around merge (1, 2) and unmerge (3, 4).
            self.merge_adapter1 = nn.Linear(d_model, d_model//4)
            self.merge_adapter2 = nn.Linear(d_model//4, d_model)
            self.merge_adapter3 = nn.Linear(d_model, d_model//4)
            self.merge_adapter4 = nn.Linear(d_model//4, d_model)

        adapter_class = functools.partial(
            local_Adapter,
            in_channels=d_model,
            adapter_channels=d_model//2,
            kernel_size=(3, 1, 1),
        )
        self.additional_local_adapter = additional_local_adapter

        if self.additional_local_adapter == True:
            self.local_adapter_pre = adapter_class()
            self.local_adapter2 = Adapter(d_model)

        self.local_adapter_aft = adapter_class()
        self.local_adapter1 = Adapter(d_model)

    def attention(self, x: torch.Tensor, scale) -> torch.Tensor:
        """Multi-head self-attention on batch-first ``(N, L, C)`` input using ``self.attn``'s weights.

        Args:
            x: ``(N, L, C)`` tokens.
            scale: ``None`` for plain attention. Otherwise it is forwarded as
                ``frame_scale`` (per-group key weights over ``self.merge_token``
                groups) to scaled_dot_product_attention.
        """
        B, L, _ = x.size()
        H = self.attn.num_heads

        qkv = F.linear(x, weight=self.attn.in_proj_weight, bias=self.attn.in_proj_bias)
        # (N, L, 3C) -> (N, 3H, L, C/H); the first H head-chunks are q, then k, then v.
        qkv = qkv.view(B, L, H * 3, -1).permute(0, 2, 1, 3)
        q, k, v = qkv.split([H, H, H], dim=1)
        out = scaled_dot_product_attention(q, k, v, self.merge_token, scale)
        out = out.permute(0, 2, 1, 3).flatten(-2)
        return self.attn.out_proj(out)

    def forward(self, x: torch.Tensor, T):
        """Apply the block to batch-first ``(B*T, P, C)`` tokens; ``T`` is frames per clip."""
        BT, P, C = x.size()
        B = BT // T
        use_global = self.present_layer in self.Block_layers

        if self.additional_local_adapter == True:
            x = self.local_adapter_pre(x, T)
            x = self.local_adapter2(x)

        x = x + self.attention(self.ln_1(x), None)

        if use_global:  # Global adapter
            # Clone so the branch's saved activations are unaffected by any in-place
            # update local_adapter_aft may apply to x below.
            x_merged = x.clone()
            # Merge along time: (B, T, P, C/4) in, merge_token slots out.
            x_merged, scale, cum_indices = merge(self.merge_adapter1(x_merged.view(B, T, P, C)), self.merge_token)
            x_merged = self.merge_adapter2(x_merged)
            # Average the class token over the merged slots, then flatten the patch tokens.
            merge_cls = torch.mean(x_merged[:, :, :1, :], dim=1)
            x_merged = x_merged[:, :, 1:, :]
            x_merged = x_merged.contiguous().view(B, self.merge_token * (P-1), C)
            x_merged = torch.cat((merge_cls, x_merged), dim=1)

            x_merged = x_merged + self.attention(self.ln_1(x_merged), scale)
            # Re-split into slots, copying the shared class token to each, then unmerge to T frames.
            merge_cls = x_merged[:, :1, :].unsqueeze(1).expand(-1, self.merge_token, -1, -1)
            x_merged = x_merged[:, 1:, :].view(B, self.merge_token, P-1, C)
            x_merged = torch.cat((merge_cls, x_merged), dim=2)
            x_merged = self.merge_adapter4(unmerge(self.merge_adapter3(x_merged), scale, cum_indices, T).view(B*T, P, C//4))

        x = self.local_adapter_aft(x, T)
        x = self.local_adapter1(x)
        if use_global:
            x = x + x_merged

        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    """Stack of ``layers`` ResidualAttentionBlocks sharing width, heads and adapter settings.

    Block ``i`` receives ``present_layer=i``, which it checks against
    ``Block_layers`` to decide whether to build the global merge branch.
    """

    def __init__(self, width: int,
        layers: int,
        heads: int,
        Block_layers: list,
        merge_token: int,
        additional_local_adapter: bool,
        attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.Block_layers = Block_layers
        self.merge_token = merge_token

        # Everything except the layer index is shared by all blocks.
        make_block = functools.partial(
            ResidualAttentionBlock,
            d_model=width,
            n_head=heads,
            Block_layers=Block_layers,
            merge_token=merge_token,
            additional_local_adapter=additional_local_adapter,
            attn_mask=attn_mask,
        )
        self.resblocks = nn.ModuleList([make_block(present_layer=idx) for idx in range(layers)])

    def forward(self, x: torch.Tensor, T):
        """Run the blocks in order, passing the frame count ``T`` to each."""
        for block in self.resblocks:
            x = block(x, T)
        return x



class VisionTransformer(nn.Module):
    """CLIP ViT applied frame-by-frame to video clips, followed by a linear classifier.

    Input ``(B, 3, T, H, W)`` clips are split into ``B*T`` frames and
    patch-embedded with ``conv1``. A class token and positional embeddings are
    added, and the tokens run through the adapter-augmented Transformer. The
    per-frame class tokens are averaged over ``T`` before ``ln_post``,
    dropout and ``fc``.
    """

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 num_classes: int,
                 Block_layers: list,
                 merge_token: int = 4,
                 additional_local_adapter: bool = False
                 ):
        super().__init__()
        self.merge_token = merge_token
        self.input_resolution = input_resolution
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
                               kernel_size=patch_size, stride=patch_size, bias=False)

        init_std = width ** -0.5
        num_tokens = (input_resolution // patch_size) ** 2 + 1  # patches + class token
        self.class_embedding = nn.Parameter(init_std * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_std * torch.randn(num_tokens, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, Block_layers, merge_token, additional_local_adapter)

        self.ln_post = LayerNorm(width)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(width, num_classes)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.constant_(self.fc.bias, 0.)

    def forward(self, x: torch.Tensor):
        B, T = x.size(0), x.size(2)
        # (B, 3, T, H, W) -> (B*T, 3, H, W): every frame is encoded as an image.
        frames = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
        tokens = self.conv1(frames).flatten(-2).permute(0, 2, 1)  # (B*T, grid**2, width)
        cls_tokens = self.class_embedding.view(1, 1, -1).expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)  # (B*T, grid**2 + 1, width)
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        tokens = self.transformer(tokens, T)

        # Take each frame's class token and average it over time.
        per_frame = tokens.view(B, T, tokens.size(1), tokens.size(2))
        pooled = per_frame[:, :, 0, :].mean(dim=1)
        pooled = self.ln_post(pooled)
        pooled = self.dropout(pooled)
        return self.fc(pooled)


def build_clip(*, checkpoint_path='../ViT-L-14.pt', **kwargs):
    """Build a ViT-L/14 VisionTransformer and load CLIP visual weights into it.

    Args:
        checkpoint_path: path to the TorchScript CLIP checkpoint. The default is
            the previously hard-coded ``'../ViT-L-14.pt'``, which is resolved
            relative to the current working directory.
        **kwargs: forwarded to VisionTransformer (e.g. ``num_classes``,
            ``Block_layers``, ``merge_token``, ``additional_local_adapter``).

    Returns:
        The model. Weights are loaded with ``strict=False``, and the resulting
        missing/unexpected-key report is printed. Presumably those keys are the
        adapters and ``fc``, which the CLIP checkpoint lacks.
    """
    model = VisionTransformer(
        input_resolution=224,
        patch_size=14,
        width=1024,
        layers=24,
        heads=16,
        **kwargs,
    )
    checkpoint = torch.jit.load(checkpoint_path, map_location='cpu')
    print(model.load_state_dict(checkpoint.visual.state_dict(), strict=False))
    return model